{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can download the `requirements.txt` for this course from the workspace of this lab: click `File --> Open...`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# L2-A - Linear Quantization I: Quantize and De-quantize a Tensor\n",
    "\n",
    "In this lesson, you will learn the fundamentals of linear quantization."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The libraries are already installed in the classroom. If you're running this notebook on your own machine, you can install them as follows:\n",
    "\n",
    "```Python\n",
    "!pip install torch==2.1.1\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Quantization with Random `Scale` and `Zero Point`\n",
    "\n",
    "- Implement linear quantization for the case where the `scale` and the `zero_point` are already known (here, they are simply picked arbitrarily)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 233
   },
   "outputs": [],
   "source": [
    "def linear_q_with_scale_and_zero_point(\n",
    "    tensor, scale, zero_point, dtype=torch.int8):\n",
    "\n",
    "    # map the real values onto the integer grid: r / scale + zero_point\n",
    "    scaled_and_shifted_tensor = tensor / scale + zero_point\n",
    "\n",
    "    rounded_tensor = torch.round(scaled_and_shifted_tensor)\n",
    "\n",
    "    # smallest and largest values representable in the target dtype\n",
    "    q_min = torch.iinfo(dtype).min\n",
    "    q_max = torch.iinfo(dtype).max\n",
    "\n",
    "    # saturate out-of-range values, then cast to the integer dtype\n",
    "    q_tensor = rounded_tensor.clamp(q_min, q_max).to(dtype)\n",
    "\n",
    "    return q_tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 114,
    "id": "Oz0fFY6gMXI6"
   },
   "outputs": [],
   "source": [
    "### a dummy tensor to test the implementation\n",
    "test_tensor = torch.tensor(\n",
    "    [[191.6, -13.5, 728.6],\n",
    "     [92.14, 295.5,  -184],\n",
    "     [0,     684.6, 245.5]]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "### these are arbitrarily chosen values for \"scale\" and \"zero_point\",\n",
    "### used only to test the implementation\n",
    "scale = 3.5\n",
    "zero_point = -70"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "quantized_tensor = linear_q_with_scale_and_zero_point(\n",
    "    test_tensor, scale, zero_point)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "quantized_tensor"
   ]
  },
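  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of what the function above computes, write $r$ for an original value, $s$ for `scale`, $z$ for `zero_point` and $q$ for the quantized value (these symbols are chosen here only for illustration):\n",
    "\n",
    "$$q = \\mathrm{clamp}\\left(\\mathrm{round}\\left(\\frac{r}{s} + z\\right),\\ q_{\\min},\\ q_{\\max}\\right)$$\n",
    "\n",
    "With $s = 3.5$, $z = -70$ and `torch.int8` (so $q_{\\min} = -128$ and $q_{\\max} = 127$), two entries of `test_tensor` work out by hand as:\n",
    "\n",
    "$$191.6 \\;\\to\\; \\mathrm{round}(54.74 - 70) = -15, \\qquad 728.6 \\;\\to\\; \\mathrm{round}(208.17 - 70) = 138 \\;\\to\\; 127 \\text{ (clamped)}$$\n",
    "\n",
    "The second value saturates at $q_{\\max}$, so it cannot be recovered exactly when the tensor is dequantized."
   ]
  },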
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dequantization with Random `Scale` and `Zero Point`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Now dequantize the tensor to see how much precision the quantization lost."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "dequantized_tensor = scale * (quantized_tensor.float() - zero_point)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "# this was the original tensor\n",
    "# [[191.6, -13.5, 728.6],\n",
    "#  [92.14, 295.5,  -184],\n",
    "#  [0,     684.6, 245.5]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "dequantized_tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "### without casting to float\n",
    "scale * (quantized_tensor - zero_point)"
   ]
  },
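  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a hand check of the dequantization step, write $q$ for a quantized value, $s$ for `scale`, $z$ for `zero_point` and $\\hat{r}$ for the recovered value (symbols chosen here only for illustration):\n",
    "\n",
    "$$\\hat{r} = s\\,(q - z)$$\n",
    "\n",
    "With $s = 3.5$ and $z = -70$:\n",
    "\n",
    "$$q = -15 \\;\\to\\; 3.5 \\times 55 = 192.5 \\quad (\\text{original } 191.6), \\qquad q = 127 \\;\\to\\; 3.5 \\times 197 = 689.5 \\quad (\\text{original } 728.6)$$\n",
    "\n",
    "The intermediate value $q - z = 197$ is larger than the `int8` maximum of 127, which is why the subtraction should be done after casting to float."
   ]
  },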
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "def linear_dequantization(quantized_tensor, scale, zero_point):\n",
    "    return scale * (quantized_tensor.float() - zero_point)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Calculate `dequantized_tensor` using the function `linear_dequantization`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "dequantized_tensor = linear_dequantization(\n",
    "    quantized_tensor, scale, zero_point)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Print the results of the `dequantized_tensor`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "dequantized_tensor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Quantization Error"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Import `plot_quantization_errors` from the helper file.\n",
    "- To access the `helper.py` file, click `File --> Open...` at the top left."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "from helper import plot_quantization_errors"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Plot the quantization results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 425
    },
    "height": 46,
    "id": "5Jr5S5ZBNhkj",
    "outputId": "2edb875f-db33-48ea-aec5-1bdfb7e1f244"
   },
   "outputs": [],
   "source": [
    "plot_quantization_errors(test_tensor, quantized_tensor,\n",
    "                         dequantized_tensor)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note:** For the plot above, `Quantization Error Tensor = abs(Original Tensor - Dequantized Tensor)`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Calculate an \"overall\" quantization error using the [Mean Squared Error](https://en.wikipedia.org/wiki/Mean_squared_error) (MSE)."
   ]
  },
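  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, with $r_i$ the entries of the original tensor, $\\hat{r}_i$ the corresponding dequantized entries and $N$ the number of entries (notation introduced here only for illustration), the mean squared error is:\n",
    "\n",
    "$$\\mathrm{MSE} = \\frac{1}{N}\\sum_{i=1}^{N}\\left(\\hat{r}_i - r_i\\right)^2$$\n",
    "\n",
    "For the $3 \\times 3$ `test_tensor`, $N = 9$. The following cells build this up one step at a time: the difference, its square, and then the mean."
   ]
  },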
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "dequantized_tensor - test_tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "(dequantized_tensor - test_tensor).square()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "(dequantized_tensor - test_tensor).square().mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [
    "kd9Q2SD0MJGw",
    "E1palk4uRQX9",
    "oAexFpXiX1PW",
    "ChnEqFPYMn3p",
    "MFx2m7RmzRd5",
    "LbPjb9OOi0Xp",
    "6WopWDYWQr7X",
    "LYfTqh_VMTzT",
    "4YZP-9XTNkur",
    "dpKLMdCYvT_W",
    "2hoC5tcJznoI",
    "S68UGldKRnJc",
    "Y4Qfalu9vtHv",
    "WJN9IfVLTFNd",
    "qC0X9ux6JEmi",
    "AkwpMs-C5ccj",
    "EDpd5Te632KY",
    "JA1-rcLz4t4D",
    "kROAEGfdDsau",
    "oo4BCLpsDw3t",
    "h2gK-eALFc8U",
    "yDin7Rm6Dzqu",
    "X6J9ZiyHWzHa",
    "qUw1gQUu5yIe",
    "vGll7vBT6BGI",
    "Cl3AUuDuAH5w",
    "8NS1TnQt6E6v",
    "vcRy85lACotg"
   ],
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
